Árvores de Classificação e Regressão

Classification And Regression Tree (CART) - Breiman et al. (1984)

Deivison Venicio Souza

27/outubro/2018

library(rmarkdown)
knitr::opts_chunk$set(fig.align="center", cache=F, prompt=FALSE, comment = NA, eval = TRUE)

1 Descrição da atividade

Vamos considerar a aplicação de uma árvore de regressão à base de dados abalone, do pacote AppliedPredictiveModeling. Os dados referem-se a 4177 espécimes de abalone, tipo de molusco encontrado ao longo das águas costeiras de todos os continentes. A variável resposta é a idade do molusco, aferida pelo número de anéis internos, que é um procedimento demorado e pouco adequado. O objetivo é ajustar um modelo que permita estimar a idade a partir de outras medidas, que são obtidas com maior facilidade. Para maiores detalhes a respeito da base, consultar a documentação e o link fornecido.

Para a análise, as primeiras 3000 linhas deverão ser usadas para ajuste, e as demais para validação.

  1. Qual o tamanho da árvore (número de nós finais) selecionada por validação cruzada? Quantas são as partições? Nota: Fixe a semente com set.seed(1). Estabeleça cp = 0.001 para o processo de poda.

  2. Quantas covariáveis aparecem no ajuste da árvore?

  3. Qual a idade estimada para moluscos com:
  1. ShellWeight=0.18 e ShuckedWeight=0.25;
  2. ShellWeight=0.31 e ShuckedWeight=0.45?
  1. Qual o resíduo para cada um dos dados? Considere, para o primeiro, Rings=8 e para o segundo Rings=10.

  2. Usando os dados de validação, calcule e apresente o valor da soma de quadrados de resíduos.

2 Carregando pacote

library(AppliedPredictiveModeling)          # contém o conjunto de dados "abalone"
library(PerformanceAnalytics)               # Gráfico da matriz de correlações
library(rpart)                              # Ajustar um modelo de árvore de regressão (CART)
library(rpart.plot)                         # Plotar uma árvore de regressão
library(xlsx)                               # Salvar para excel
library(ggplot2)                            # Vizualização gráfica
library(data.table)                         # manipulação de dados
library(gridExtra)                          # tabela em ggplot2
library(ggthemes)
#library(caret)     

3 Carregando dados e fazendo manipulações

?abalone
data(abalone)                                     # carrega conjunto de dados
str(abalone)                                      # estrutura do data frame
'data.frame':   4177 obs. of  9 variables:
 $ Type         : Factor w/ 3 levels "F","I","M": 3 3 1 3 2 2 1 1 3 1 ...
 $ LongestShell : num  0.455 0.35 0.53 0.44 0.33 0.425 0.53 0.545 0.475 0.55 ...
 $ Diameter     : num  0.365 0.265 0.42 0.365 0.255 0.3 0.415 0.425 0.37 0.44 ...
 $ Height       : num  0.095 0.09 0.135 0.125 0.08 0.095 0.15 0.125 0.125 0.15 ...
 $ WholeWeight  : num  0.514 0.226 0.677 0.516 0.205 ...
 $ ShuckedWeight: num  0.2245 0.0995 0.2565 0.2155 0.0895 ...
 $ VisceraWeight: num  0.101 0.0485 0.1415 0.114 0.0395 ...
 $ ShellWeight  : num  0.15 0.07 0.21 0.155 0.055 0.12 0.33 0.26 0.165 0.32 ...
 $ Rings        : int  15 7 9 10 7 8 20 16 9 19 ...
head(abalone, 10)                                 # ler as primeiras 10 linhas

4 Análise exploratória dos dados

4.1 Distribuição das variáveis

#pairs((abalone[,1:ncol(abalone)]), panel=panel.smooth)
#ggpairs(abalone)
chart.Correlation(abalone[,2:9], histogram=TRUE, pch=19)

ggplot(abalone, aes(x=Type, fill=Type)) + geom_bar() +
    scale_fill_brewer(palette = "Set1")

4.2 Tabelas de frequências

with(abalone,table(Type))     # tabela freq. p/ variável Type
Type
   F    I    M 
1307 1342 1528 

with(abalone,prop.table(table(Type))*100) # tabela freq. (%) p/ variável Type
Type
       F        I        M 
31.29040 32.12832 36.58128 

with(abalone,table(Rings))    # tabela freq. (%) p/ variável Rings (número de anéis)
Rings
  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18 
  1   1  15  57 115 259 391 568 689 634 487 267 203 126 103  67  58  42 
 19  20  21  22  23  24  25  26  27  29 
 32  26  14   6   9   2   1   1   2   1 

5 Dividindo conjunto de dados: Treino e Teste

Para avaliar a capacidade preditiva do modelo a ser ajustado, a base foi dividida, aleatoriamente, em duas novas bases: a primeira, com 3000 observações, para o ajuste e demais foram deixadas para avaliação do modelo (base de teste).

train <- abalone[1:3000,]
test <- abalone[(nrow(train)+1):nrow(abalone),]

with(train,table(Rings))    # tabela freq. p/ variável Rings (número de anéis)
Rings
  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18 
  1   1  12  41  87 185 298 406 498 452 315 196 141  91  77  45  47  30 
 19  20  21  22  23  25  26  27  29 
 24  21  12   6   9   1   1   2   1 
with(test,table(Rings))     # tabela freq. p/ variável Rings (número de anéis)
Rings
  3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20 
  3  16  28  74  93 162 191 182 172  71  62  35  26  22  11  12   8   5 
 21  24 
  2   2 

6 Ajuste do modelo - Árvore de Regressão (Package Rpart) - Implementa o CART

A variável resposta alvo da modelagem será Rings (número de anéis) do molusco. Para se obter a idade do molusco deve-se somar Rings + 1,5. A função rpart tem o seguinte escopo:

rpart(formula, data, weights, subset, na.action = na.rpart, method, model = FALSE, x = FALSE, y = TRUE, parms, control, cost, …)

O parâmetro control recebe uma lista de opções que controlam detalhes do algoritmo rpart. O escopo geral e os parâmetros passíveis de serem controlados estão detalhados abaixo:

rpart.control(minsplit = 20, minbucket = round(minsplit/3), cp = 0.01, maxcompete = 4, maxsurrogate = 5, usesurrogate = 2, xval = 10, surrogatestyle = 0, maxdepth = 30, …)

minsplit = o número mínimo de observações que devem existir em um nó para que uma divisão seja tentada.

minbucket = o número mínimo de observações em qualquer nó terminal . Se apenas um dos minbucket ou minsplit for especificado, o código define minsplit para minbucket*3 ou minbucket para minsplit/3, conforme apropriado.

cp (parâmetro de complexidade) = Qualquer divisão que não diminua a falta total (SSEpai?) de ajuste por um fator de cp não é tentada. Por exemplo, com anova splitting, isso significa que o R-quadrado total deve aumentar em cp em cada etapa. O principal papel desse parâmetro é economizar tempo de computação removendo as divisões que obviamente não valem a pena. Essencialmente, o usuário informa ao programa que qualquer divisão que não melhore o ajuste por cp provavelmente será eliminada por validação cruzada, e que, portanto, o programa não precisa buscá-la.

maxcompete = o número de divisões do concorrente retidas na saída. É útil saber não apenas qual divisão foi escolhida, mas qual variável veio em segundo, terceiro, etc.

maxsurrogate = o número de divisões substitutas retidas na saída. Se isso for definido como zero, o tempo de cálculo será reduzido, uma vez que aproximadamente metade do tempo computacional (diferente de setup) é usado na busca por splits substitutos.

usesurrogate = como usar substitutos no processo de divisão. 0 significa apenas exibição; uma observação com um valor ausente para a regra de divisão primária não é enviada mais abaixo na árvore. 1 significa usar substitutos, em ordem, para dividir os sujeitos que não têm a variável primária; se todos os substitutos estiverem ausentes, a observação não será dividida. Para o valor 2, se todos os substitutos estiverem ausentes, envie a observação na direção majoritária. Um valor de 0 corresponde à ação da árvore e 2 às recomendações de Breiman et.al (1984).

xval = número de validações cruzadas.

surrogatestyle = controla a seleção de um melhor substituto. Se definido como 0 (padrão), o programa usa o número total de classificações corretas para uma variável substituta em potencial, se definida como 1, usa a porcentagem correta, calculada sobre os valores não ausentes do substituto. A primeira opção penaliza mais severamente as covariáveis com um grande número de valores omissos.

maxdepth = Defina a profundidade máxima de qualquer nó da árvore final, com o nó raiz contado como profundidade 0. Valores maiores que 30 rpart fornecerão resultados sem sentido em máquinas de 32 bits.

Usando cp=1: Um valor de cp = 1 sempre resultará em uma árvore sem divisões

# Um valor de cp = 1 sempre resultará em uma árvore sem divisões
set.seed(1)
(tree.cp1 <- rpart(Rings ~ ., data = train, control = rpart.control(cp = 1)))
n= 3000 

node), split, n, deviance, yval
      * denotes terminal node

1) root 3000 33030.56 9.941 *

Usando os parâmetros default:

# Aqui o valor de cp = 0.01 (padrão)
set.seed(1)
(tree <- rpart(Rings ~ ., data = train, method="anova"))
n= 3000 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 3000 33030.5600  9.941000  
   2) ShellWeight< 0.19475 1248  6526.1660  7.853365  
     4) ShellWeight< 0.06775 333   874.7988  5.939940 *
     5) ShellWeight>=0.06775 915  3988.4870  8.549727  
      10) ShellWeight< 0.11925 367  1046.7740  7.771117 *
      11) ShellWeight>=0.11925 548  2570.2240  9.071168 *
   3) ShellWeight>=0.19475 1752 17190.9400 11.428080  
     6) ShellWeight< 0.4095 1417 11055.2900 10.956250  
      12) ShuckedWeight>=0.39975 834  4041.4240 10.395680  
        24) ShellWeight< 0.29275 307   705.5179  9.397394 *
        25) ShellWeight>=0.29275 527  2851.7270 10.977230 *
      13) ShuckedWeight< 0.39975 583  6376.8990 11.758150  
        26) ShellWeight< 0.25475 321  2310.4740 10.707170  
          52) ShuckedWeight>=0.2425 271  1280.6350 10.195570 *
          53) ShuckedWeight< 0.2425 50   574.4800 13.480000 *
        27) ShellWeight>=0.25475 262  3277.4500 13.045800 *
     7) ShellWeight>=0.4095 335  4485.8090 13.423880  
      14) ShuckedWeight>=0.589 237  2425.2490 12.497890  
        28) ShellWeight< 0.568 188  1250.3190 11.797870 *
        29) ShellWeight>=0.568 49   729.3469 15.183670 *
      15) ShuckedWeight< 0.589 98  1365.8880 15.663270 *

A função print imprime a árvore construída (neste caso, com os parâmetros de controle padrão). Foram geradas 21 regras (nós), que proporcionaram a criação de 10 divisões e 11 nós terminais (leaf). O simbolo * denota um nó terminal. Somente duas variáveis foram escolhidas pelo algoritmo para trabalhar as divisões: ShellWeight e ShuckedWeight.

print(tree)
n= 3000 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 3000 33030.5600  9.941000  
   2) ShellWeight< 0.19475 1248  6526.1660  7.853365  
     4) ShellWeight< 0.06775 333   874.7988  5.939940 *
     5) ShellWeight>=0.06775 915  3988.4870  8.549727  
      10) ShellWeight< 0.11925 367  1046.7740  7.771117 *
      11) ShellWeight>=0.11925 548  2570.2240  9.071168 *
   3) ShellWeight>=0.19475 1752 17190.9400 11.428080  
     6) ShellWeight< 0.4095 1417 11055.2900 10.956250  
      12) ShuckedWeight>=0.39975 834  4041.4240 10.395680  
        24) ShellWeight< 0.29275 307   705.5179  9.397394 *
        25) ShellWeight>=0.29275 527  2851.7270 10.977230 *
      13) ShuckedWeight< 0.39975 583  6376.8990 11.758150  
        26) ShellWeight< 0.25475 321  2310.4740 10.707170  
          52) ShuckedWeight>=0.2425 271  1280.6350 10.195570 *
          53) ShuckedWeight< 0.2425 50   574.4800 13.480000 *
        27) ShellWeight>=0.25475 262  3277.4500 13.045800 *
     7) ShellWeight>=0.4095 335  4485.8090 13.423880  
      14) ShuckedWeight>=0.589 237  2425.2490 12.497890  
        28) ShellWeight< 0.568 188  1250.3190 11.797870 *
        29) ShellWeight>=0.568 49   729.3469 15.183670 *
      15) ShuckedWeight< 0.589 98  1365.8880 15.663270 *

A função summary fornece um resumo amplo do modelo ajustado. A função também reconhece a opção cp que permite ao usuário imprimir apenas poucas divisões superiores. Pode-se, por exemplo, imprimir apenas divisões com cp superiores à 0.02. Na tabela são informados os valores de CP (parâmetro de complexidade), nsplit (número de divisões), rel error (erro relativo), xerror e xstd (erro padrão obtido com base no conjunto de validação). O score CP é impresso da menor árvore (sem divisões) - que terá o maior CP - para a maior e mais complexa (10 divisões), que terá obviamente o menor CP. O número de nós terminais é sempre 1 + o número de divisões (10), portanto têm-se 11 nós terminais. O rel error (erro relativo) é obtido fazendo-se \(1-R^2\) (coeficiente de determinação), semelhante à regressão linear (parece ser as previsões do modelo sobre o conjunto de treinamento). Do contrário, o xerror está relacionado à estatística PRESS (parece ser o erro do modelo sobre o conjunto de validação). O \(R^2\) é calculado fazendo-se SSR/SST, em que: SSR = soma total do erro (valor previsto - valor médio)^2; e SST = soma total do erro (valor real - valor médio)^2. O \(R^2\) também pode ser calculado por 1-SSE /SST. Em que: SSE = soma total do erro (valor real - valor previsto)^2.

A estatística de soma dos quadrados de erros residuais previstos (PRESS) é uma forma de validação cruzada usada na análise de regressão para fornecer uma medida resumida do ajuste de um modelo a uma amostra de observações que não foram usadas para estimar o modelo. É calculado como as somas de quadrados dos resíduos de previsão para essas observações. Tendo sido produzido um modelo ajustado, cada observação, por sua vez, é removida e o modelo é refeito usando as observações restantes. O valor predito fora da amostra é calculado para a observação omitida em cada caso, e a estatística PRESS é calculada como a soma dos quadrados de todos os erros de previsão resultantes.

Tendo sido produzido um modelo ajustado, cada observação, por sua vez, é removida e o modelo é refeito usando as observações restantes. O valor predito fora da amostra é calculado para a observação omitida em cada caso, e a estatística PRESS é calculada como a soma dos quadrados de todos os erros de previsão resultantes.

A primeira divisão parece melhorar consideravelmente o ajuste mais. A regra 1-SE escolheria uma árvore com ? divisões. A melhoria (CP) é a alteração percentual no SSE (soma de erro quadrático) para essa divisão, ou seja, \(1 - (SS_{right} + SS_{left})/SS_{parent}\), que é o ganho em \(R^2\) para o ajuste. (Fonte: An Introduction to Recursive Partitioning Using the RPART Routines) (Therneau et al., 2018, p.37). Assim, por exemplo, o valor CP = 0.281965 para o nó raiz (sem divisão) é obtido fazendo-se: \(1 - (6526.166 + 17190.94)/33030.56 = 0.281965\). Em que: \(SS_{right}\) = soma de erro quadrático na parte direita da árvore; \(SS_{left}\) = soma de erro quadrático na parte esquerda da árvore; \(SS_{parent}\) = soma de erro quadrático no nó “pai”.

Como escolher a melhor árvore?

A convenção é usar a melhor árvore (menor erro de validação cruzada) ou a menor árvore (mais simples) dentro de 1 erro padrão (SE) da melhor árvore (regra 1-SE). A regra 1SE aconselha a procurar a árvore de erro mínimo, mas depois subir 1-SE na busca de uma árvore menos complexa (mais simples). Asssim, para este problema em específico, inicialmente, encontramos que a divisão nsplits = 10 foi aquela com menor erro na validação (xerror = 0.5677094, com xstd = 0.02477702). Em seguida, fazendo-se o cálculo para a regra 1-SE ter-se-ia: \(0.5677094 + 1*0.02477702 = 0.5924864\). Então, o valor obtido pela regra 1-SE foi 0.5924864. Portanto, a regra 1-SE sugere que uma árvore com 7 divisões (xerror = 0.5826542 é menor que o limiar da regra 1-SE, ou seja, está dentro de 1-SE da melhor árvore) faz efetivamente o mesmo trabalho do que a árvore com 10 divisões (que possui o menor xerror = 0.5677094). Por fim, a árvore com 7 divisões pode ser considerada o modelo modelo mais parcimonioso, cujo erro não é mais do que 1-SE (erro padrão) acima do erro do melhor modelo (árvore com 10 divisões).

**Muitas vezes, uma regra de “erro de um padrão” é usada com validação cruzada, na qual escolhemos o modelo mais parcimonioso, cujo erro não é mais do que um erro padrão acima do erro do melhor modelo.“**

summary(tree, cp = 0.02)      # um resumo da árvore (imprimindo divisões com cp superiores à 0.02)
Call:
rpart(formula = Rings ~ ., data = train, method = "anova")
  n= 3000 

           CP nsplit rel error    xerror       xstd
1  0.28196475      0 1.0000000 1.0005429 0.03843041
2  0.05034368      1 0.7180352 0.7275070 0.03044610
3  0.04994897      2 0.6676916 0.6829629 0.02903991
4  0.02158515      3 0.6177426 0.6471234 0.02758708
5  0.02103120      5 0.5745723 0.6295624 0.02703700
6  0.01465854      6 0.5535411 0.5990034 0.02605649
7  0.01378599      7 0.5388826 0.5826542 0.02512203
8  0.01349002      8 0.5250966 0.5817429 0.02510887
9  0.01124683      9 0.5116066 0.5771405 0.02495954
10 0.01000000     10 0.5003597 0.5677094 0.02477702

Variable importance
  ShellWeight   WholeWeight      Diameter  LongestShell VisceraWeight 
           21            17            16            15            15 
ShuckedWeight        Height 
           14             1 

Node number 1: 3000 observations,    complexity param=0.2819648
  mean=9.941, MSE=11.01019 
  left son=2 (1248 obs) right son=3 (1752 obs)
  Primary splits:
      ShellWeight   < 0.19475 to the left,  improve=0.2819648, (0 missing)
      Height        < 0.1225  to the left,  improve=0.2545448, (0 missing)
      WholeWeight   < 0.5665  to the left,  improve=0.2490117, (0 missing)
      VisceraWeight < 0.12075 to the left,  improve=0.2484626, (0 missing)
      Diameter      < 0.3475  to the left,  improve=0.2434915, (0 missing)
  Surrogate splits:
      WholeWeight   < 0.65975 to the left,  agree=0.952, adj=0.885, (0 split)
      Diameter      < 0.4025  to the left,  agree=0.939, adj=0.853, (0 split)
      VisceraWeight < 0.14525 to the left,  agree=0.926, adj=0.821, (0 split)
      LongestShell  < 0.5125  to the left,  agree=0.923, adj=0.814, (0 split)
      ShuckedWeight < 0.27125 to the left,  agree=0.908, adj=0.778, (0 split)

Node number 2: 1248 observations,    complexity param=0.05034368
  mean=7.853365, MSE=5.2293 
  left son=4 (333 obs) right son=5 (915 obs)
  Primary splits:
      ShellWeight  < 0.06775 to the left,  improve=0.2548019, (0 missing)
      Diameter     < 0.2425  to the left,  improve=0.2363086, (0 missing)
      WholeWeight  < 0.19675 to the left,  improve=0.2349016, (0 missing)
      Height       < 0.0975  to the left,  improve=0.2327504, (0 missing)
      LongestShell < 0.3425  to the left,  improve=0.2316758, (0 missing)
  Surrogate splits:
      WholeWeight   < 0.20675 to the left,  agree=0.960, adj=0.850, (0 split)
      Diameter      < 0.2725  to the left,  agree=0.952, adj=0.820, (0 split)
      LongestShell  < 0.3525  to the left,  agree=0.943, adj=0.787, (0 split)
      VisceraWeight < 0.04825 to the left,  agree=0.942, adj=0.781, (0 split)
      ShuckedWeight < 0.08525 to the left,  agree=0.936, adj=0.760, (0 split)

Node number 3: 1752 observations,    complexity param=0.04994897
  mean=11.42808, MSE=9.812179 
  left son=6 (1417 obs) right son=7 (335 obs)
  Primary splits:
      ShellWeight   < 0.4095  to the left,  improve=0.09597162, (0 missing)
      Height        < 0.1675  to the left,  improve=0.06080490, (0 missing)
      Diameter      < 0.5125  to the left,  improve=0.03221533, (0 missing)
      WholeWeight   < 1.48025 to the left,  improve=0.03003377, (0 missing)
      ShuckedWeight < 0.28275 to the right, improve=0.01708752, (0 missing)
  Surrogate splits:
      WholeWeight   < 1.444   to the left,  agree=0.916, adj=0.558, (0 split)
      Diameter      < 0.5275  to the left,  agree=0.902, adj=0.487, (0 split)
      LongestShell  < 0.6725  to the left,  agree=0.891, adj=0.430, (0 split)
      VisceraWeight < 0.35775 to the left,  agree=0.880, adj=0.370, (0 split)
      Height        < 0.1875  to the left,  agree=0.874, adj=0.340, (0 split)

Node number 4: 333 observations
  mean=5.93994, MSE=2.627023 

Node number 5: 915 observations
  mean=8.549727, MSE=4.359003 

Node number 6: 1417 observations,    complexity param=0.02158515
  mean=10.95625, MSE=7.801896 
  left son=12 (834 obs) right son=13 (583 obs)
  Primary splits:
      ShuckedWeight < 0.39975 to the right, improve=0.05761623, (0 missing)
      ShellWeight   < 0.25475 to the left,  improve=0.03744456, (0 missing)
      Height        < 0.1625  to the left,  improve=0.02163832, (0 missing)
      Diameter      < 0.4225  to the right, improve=0.02072119, (0 missing)
      LongestShell  < 0.5325  to the right, improve=0.01690833, (0 missing)
  Surrogate splits:
      WholeWeight   < 0.94775 to the right, agree=0.873, adj=0.691, (0 split)
      LongestShell  < 0.5725  to the right, agree=0.829, adj=0.583, (0 split)
      Diameter      < 0.4525  to the right, agree=0.816, adj=0.552, (0 split)
      VisceraWeight < 0.19425 to the right, agree=0.807, adj=0.530, (0 split)
      ShellWeight   < 0.26325 to the right, agree=0.730, adj=0.345, (0 split)

Node number 7: 335 observations,    complexity param=0.0210312
  mean=13.42388, MSE=13.39047 
  left son=14 (237 obs) right son=15 (98 obs)
  Primary splits:
      ShuckedWeight < 0.589   to the right, improve=0.15486000, (0 missing)
      VisceraWeight < 0.29475 to the right, improve=0.08115390, (0 missing)
      Diameter      < 0.4925  to the right, improve=0.07924906, (0 missing)
      ShellWeight   < 0.568   to the left,  improve=0.07889302, (0 missing)
      WholeWeight   < 1.4395  to the right, improve=0.06980644, (0 missing)
  Surrogate splits:
      WholeWeight   < 1.38925 to the right, agree=0.896, adj=0.643, (0 split)
      VisceraWeight < 0.29825 to the right, agree=0.866, adj=0.541, (0 split)
      LongestShell  < 0.6375  to the right, agree=0.833, adj=0.429, (0 split)
      Diameter      < 0.5175  to the right, agree=0.791, adj=0.286, (0 split)
      ShellWeight   < 0.41175 to the right, agree=0.734, adj=0.092, (0 split)

Node number 12: 834 observations
  mean=10.39568, MSE=4.845833 

Node number 13: 583 observations,    complexity param=0.02158515
  mean=11.75815, MSE=10.93808 
  left son=26 (321 obs) right son=27 (262 obs)
  Primary splits:
      ShellWeight   < 0.25475 to the left,  improve=0.12372390, (0 missing)
      WholeWeight   < 0.942   to the left,  improve=0.09615794, (0 missing)
      Height        < 0.1525  to the left,  improve=0.07877211, (0 missing)
      Type          splits as  RLR,         improve=0.05910621, (0 missing)
      ShuckedWeight < 0.2545  to the right, improve=0.03263195, (0 missing)
  Surrogate splits:
      WholeWeight   < 0.84075 to the left,  agree=0.844, adj=0.653, (0 split)
      Diameter      < 0.4425  to the left,  agree=0.710, adj=0.355, (0 split)
      LongestShell  < 0.5525  to the left,  agree=0.703, adj=0.340, (0 split)
      Height        < 0.1575  to the left,  agree=0.702, adj=0.336, (0 split)
      VisceraWeight < 0.19625 to the left,  agree=0.683, adj=0.294, (0 split)

Node number 14: 237 observations
  mean=12.49789, MSE=10.23312 

Node number 15: 98 observations
  mean=15.66327, MSE=13.93763 

Node number 26: 321 observations
  mean=10.70717, MSE=7.197737 

Node number 27: 262 observations
  mean=13.0458, MSE=12.50935 

O pacote rpart.plot permite gerar árvores customizadas a partir de um objeto rpart:

rpart.plot(tree)

heat.tree <- function(tree, low.is.green = FALSE, ...) { # dots args passed to prp
y <- tree$frame$yval
if(low.is.green)
y <- -y
max <- max(y)
min <- min(y)
cols <- rainbow(99, end = .36)[
ifelse(y > y[1], (y-y[1]) * (99-50) / (max-y[1]) + 50,
(y-min) * (50-1) / (y[1]-min) + 1)]
prp(tree, branch.col = cols, box.col = cols, ...)
}

heat.tree(tree, type = 4, varlen = 0, faclen = 0, fallen.leaves = TRUE)

par(mfrow = c(4,3))
for(iframe in 1:nrow(tree$frame)) {
cols <- ifelse(1:nrow(tree$frame) <= iframe, "black", "gray")
prp(tree, col = cols, branch.col = cols, split.col = cols)
}

Usando a função printcp(tree) pode-se obter a tabela custo-complexidade e outras informações adicionais para o modelo ajustado. Em primeiro lugar, a função relata as variáveis atualmente usuadas para a construção da árvore de regressão: ShellWeight e ShuckedWeight. Também é mostrado o erro no nó raiz, isto é, o erro da árvore sem divisões. Tal erro nada mais é do que o MSE, dado por SSE/N. Onde: SSE = soma de erro quadrático e N = número de observações. Neste primeiro momento, o cálculo do SSE é dado pela soma da diferença entre os valores empíricos de Rings e a média aritmética de Rings em todo cojunto de treinamento.

cp <- printcp(tree)         # Tabela de custo-complexidade.

Regression tree:
rpart(formula = Rings ~ ., data = train, method = "anova")

Variables actually used in tree construction:
[1] ShellWeight   ShuckedWeight

Root node error: 33031/3000 = 11.01

n= 3000 

         CP nsplit rel error  xerror     xstd
1  0.281965      0   1.00000 1.00054 0.038430
2  0.050344      1   0.71804 0.72751 0.030446
3  0.049949      2   0.66769 0.68296 0.029040
4  0.021585      3   0.61774 0.64712 0.027587
5  0.021031      5   0.57457 0.62956 0.027037
6  0.014659      6   0.55354 0.59900 0.026056
7  0.013786      7   0.53888 0.58265 0.025122
8  0.013490      8   0.52510 0.58174 0.025109
9  0.011247      9   0.51161 0.57714 0.024960
10 0.010000     10   0.50036 0.56771 0.024777

rsq.val <- 1-cp[,c(3,4)]    # extrai o Rsquared
#R2 <- 1 - (sum((actual-predict )^2)/sum((actual-mean(actual))^2))

Pode-se plotar o erro de validação cruzada (10 folds) em relação ao parâmetro de complexidade (CP) usando a função plotcp(tree). A função plota no eixo X o parâmetro de complexidade, no eixo Y o erro relativo obtido por validação cruzada e no eixo Z o número de divisões (nsplits) + 1 (estou em dúvida?). As barras verticais representam o resultado da operação (xerror + xstd). Onde: xerror = Erro relativo na validação, xstd = erro padrão na validação. Assim, tem-se que as barras verticais nada mais são do que o erro relativo + 1 desvio padrão, cujos scores são obtidos por meio de valiadação cruzada. A linha pontilhada é o resultado da operação (xerror + xstd) = (0.56771 + 0.024777) = 0.5924864. O valor 0.5924864 é limiar de decisão da regra 1-SE. A idéia é de que árvores menores (menos complexas) que tenham erros de validação (xerror) menor (ou dentro) desse limiar terão desempenho semelhante a árvore com “menor estatística desempenho” na validação cruzada (menor xerror), com a vantagem de ser mais parcimoniosa (mais simples).

# curva custo-complexidade
par(mfrow=c(3,1), mar=c(4,5,5,3)) # mar(bottom, left, top, right)

plotcp(tree, upper = "splits",  minline = TRUE)     # Usando "splits" fica mais fácil ver onde cortar a árvore...          
plotcp(tree, upper = "size",  minline = TRUE)  
plotcp(tree, upper = "none",  minline = TRUE)  

#tree$cptable
cp <- sqrt(tree$cptable[,1] * c(Inf, tree$cptable[,1][-length(tree$cptable[,1])]))
y <- tree$cptable[,3]
plot(order(x=cp, decreasing = T), y=y)

Podemos obter gráficos do ajuste usando a função rsq.rpart(tree). A primeira figura mostra a relação R-squared (coeficiente de determinação) versus Number of Splits (número de divisões). Nesta figura duas curvas são plotadas indicando o R-squared Relativo e R-squared Apparent. O “rel error” é 1−R2, semelhante a regressão linear (creio que seja o erro nas observações usadas para treinar o modelo?). O “xerror” está relacionado à estatística PRESS. Este é o erro nas observações dos dados de validação cruzada. (Fonte: https://stats.stackexchange.com/questions/103018/difference-between-rel-error-and-xerror-in-rpart-regression-trees).

A segunda figura é equivalente a obtida pela função plotcp(tree). A figura sugerem que a árvore deve ser podada (prune) para incluir apenas 9 divisões.

rsq.rpart(tree)

Regression tree:
rpart(formula = Rings ~ ., data = train, method = "anova")

Variables actually used in tree construction:
[1] ShellWeight   ShuckedWeight

Root node error: 33031/3000 = 11.01

n= 3000 

         CP nsplit rel error  xerror     xstd
1  0.281965      0   1.00000 1.00054 0.038430
2  0.050344      1   0.71804 0.72751 0.030446
3  0.049949      2   0.66769 0.68296 0.029040
4  0.021585      3   0.61774 0.64712 0.027587
5  0.021031      5   0.57457 0.62956 0.027037
6  0.014659      6   0.55354 0.59900 0.026056
7  0.013786      7   0.53888 0.58265 0.025122
8  0.013490      8   0.52510 0.58174 0.025109
9  0.011247      9   0.51161 0.57714 0.024960
10 0.010000     10   0.50036 0.56771 0.024777

7 1) Desvendando a árvore de regressão

7.1 a) A árvore sem divisões

Inicialmente, pode-se considerar que Rings (variável resposta y) pode ser explicado pela sua média. Então, utilizando-se dos valores reais de y do conjunto train (n = 3000) pode-se obter a média empírica e soma de erro quadrático (Sum of Squared Errors - SSE). O cálculo do SSE é dado pela soma da diferença entre os valores empíricos de Rings e a média aritmética de Rings. Assim, têm-se:

  • Média aritmética da variável resposta y no conjunto train = 9.941;
  • Score SSE no conjunto train = 33030.557.

Quando o modelo treinado tree é impresso verifica-se que os valores 9.941 e 33030.557 irão compor a raíz da árvore de regressão: (1) root n=3000 SSE=33030.5600 Mean=9.941000. No nó raiz estão 100% dos dados do conjunto train.

setDT(train)
hatTrain <- train[,c("Rings", "ShellWeight", "ShuckedWeight")]                # seleciona colunas
hatTrain[, `:=` (MeanObs = mean(Rings))][]                                    # média aritmética e empilha
hatTrain[, `:=` (SSE = sum((Rings - MeanObs)^2))][]                           # SSE
hatTrain[, `:=` (N=.N)][]                                                     # número de observações
hatTrain[, `:=` (MSE = sum((Rings - MeanObs)^2)/length(Rings))][]             # MSE
hatTrain[, `:=` (R2 = 1-(sum((Rings-MeanObs)^2)/sum((Rings-mean(Rings))^2)))] # R-squared
hatTrain[, `:=` (RelError = 1-R2)]                                            # Relative Error

7.2 b) A primeira decisão de partição da árvore de regressão**

A partir do modelo de árvore gerado pode-se observar que apenas duas variáveis foram utilizadas pelo CART para produzir a árvore de regressão: ShellWeighte ShuckedWeight. A primeira decisão de partição do algoritmo foi: 2) ShellWeight< 0.19475 n=1248 SSE=6526.1660 Mean=7.853365. Assim, a primeira partição dividiu as 3000 observações em grupos de 1248 (42%) e 1752 (58%) (nós 2 e 3) com valores médios de Rings de 7.853365 e 11.428080, respectivamente. Tendo por base essa primeira partição (ShellWeight < 0.19475 e ShellWeight \(\geq\) 0.19475) pode-se avaliar novamente o score SSE considerando-se a média aritmética dos valores empiricos de Rings em cada partição.

Fazendo cálculos simples podemos ratificar os percentuais apresentados nos dois nós folhas iniciais da árvore de regressão:

with(train,table(ShellWeight<0.19475))        # tabela freq. p/ variável Class ShellWeight<0.19475

FALSE  TRUE 
 1752  1248 
with(train,prop.table(table(ShellWeight<0.19475))*100)

FALSE  TRUE 
 58.4  41.6 

Então, pode-se calcular as novas médias aritméticas e SSEs para a primeira partição proposta. Os resultados devem corresponder ao observado em:

2) ShellWeight< 0.19475 n=1248 SSE=6526.1660 Mean=7.853365; e

3) ShellWeight>=0.19475 n=1752 SSE=17190.9400 mean=11.428080.


hatTrain$split.1 <- "Part.1"
hatTrain$split.1[hatTrain$ShellWeight >= tree$splits[1,4]] <- "Part.2"
hatTrain[, `:=` (Mean.1 = mean(Rings)), by="split.1"]                                        # média aritmética por partição e empilha
setorder(hatTrain, split.1)                                                                  # ordenando com base em "binary.1"
hatTrain[, `:=` (SSE.1 = sum((Rings - Mean.1)^2)), by="split.1"][]                           # SSE
hatTrain[, `:=` (N.1=.N), by="split.1"][]                                                    # número de observações por partição
hatTrain[, `:=` (MSE.1 = sum((Rings - Mean.1)^2)/length(Rings)),by="split.1"][]              # MSE
hatTrain[, `:=` (CP.1 = 1 - sum(unique(SSE.1))/unique(SSE))][]                               # CP
hatTrain[, `:=` (R2.1 = 1-(sum((Rings-Mean.1)^2)/sum((Rings-mean(Rings))^2)))]               # R-squared

#1-(SSE.part1 + SSE.part2)/SST                                                        
#hatTrain[, `:=` (SE.2= sum((Rings - Mean.2)^2))][]                 # SSE total primeira partição  
#with(hatTrain,table(Particao.1))

Para melhor compreensão pode-se fazer um scatterplot das duas variáveis utilizadas (ShellWeighte ShuckedWeight) pelo algoritmo para estabelecer as partições:


p <- ggplot() + theme_stata()
(p <- p + geom_point(data=hatTrain, aes(x=ShellWeight, y=ShuckedWeight)))

Aqui, podemos visualizar a primeira partição proposta pelo algoritmo:


p <- ggplot() + theme_stata()
p <- p + geom_point(data=hatTrain, aes(x=ShellWeight, y=ShuckedWeight))
p <- p + geom_vline(xintercept=tree$splits[1,4], colour="red", linetype="dotted")
p <- p + geom_text(aes(x=tree$splits[1,4], y=1.3, label="SheW < 0.19475"), 
            colour="red", angle=90, vjust = 1.2, text=element_text(size=7))
Warning: Ignoring unknown parameters: text
(p <- p + geom_point(data=hatTrain, aes(x=ShellWeight, y=ShuckedWeight, colour=split.1)) + 
    scale_colour_manual(values = c("black", "red")) + theme(legend.position="none"))

7.2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1

  • Um nova partição

Inicialmente, como já sabe-se á arvore foi dividida em duas parte: ShellWeight< 0.19475 e e ShellWeight \(\geq\) 0.19475. Assim, vamos considerar as partições do lado esquerdo da árvore. Fazendo-se isso, a próxima partição será: 4) ShellWeight< 0.06775 333 874.7988 5.939940*, que indicará um nó terminal. Caso contrário (ShellWeight \(\geq\) 0.06775) uma nova partição foi realizada. Assim, o conjunto de dados foi dividido em duas parte tendo por base ShellWeight < 0.06775, onde 11% (n=333) dos dados tiveram valores < 0.06775 para a variável ShellWeight, e os demais 31% (n=915) tiveram valores superiores. Tendo por base essa nova partição (ShellWeight < 0.06775 e ShellWeight \(\geq\) 0.06775) pode-se avaliar novamente o score SSE considerando-se a média aritmética dos valores empiricos de y em cada partição.

Fazendo cálculos simples podemos ratificar os percentuais apresentados nos dois nós folhas iniciais da árvore de regressão:

with(hatTrain[ which(ShellWeight<0.19475),],table(ShellWeight<0.06775))   # tabela freq. p/ variável Class ShellWeight<0.06775

FALSE  TRUE 
  915   333 

Então, pode-se calcular as novas médias aritméticas e SSEs para a primeira partição proposta. Os resultados devem corresponder ao observado em:

4) ShellWeight< 0.06775 333 874.7988 5.939940 *; e

5) ShellWeight>=0.06775 915 3988.4870 8.549727.


hatTrain$split.2 <- NA

hatTrain$split.2[hatTrain$ShellWeight < 0.19475 & hatTrain$ShellWeight < 0.06775] <- "Part.3"
hatTrain$split.2[hatTrain$ShellWeight < 0.19475 & hatTrain$ShellWeight >= 0.06775] <- "Part.4"
hatTrain$split.2[hatTrain$ShellWeight > 0.19475 & hatTrain$ShellWeight < 0.4095] <- "Part.5"
hatTrain$split.2[hatTrain$ShellWeight > 0.19475 & hatTrain$ShellWeight >= 0.4095] <- "Part.6"

hatTrain[, `:=` (Mean.3 = mean(Rings)), by="split.2"]              # média aritmética por partição e empilha
setorder(hatTrain, split.2)                                        # ordenando com base em "binary.1"
hatTrain[, `:=` (SSE.3 = sum((Rings - Mean.3)^2)), by="split.2"][] # SSE
#with(hatTrain,table(split.2))

Podemos visualizar as novas partições propostas pelo algoritmo:


p <- ggplot() + theme_stata()
p <- p + geom_point(data=hatTrain, aes(x=ShellWeight, y=ShuckedWeight, colour=split.2)) +
  scale_colour_manual(values = c("black", "red", "blue", "green")) + theme(legend.position="none")
p <- p + geom_vline(xintercept=0.19475, colour="red", linetype="dotted")
p <- p + geom_vline(xintercept=0.06775, colour="black", linetype="dotted")
p <- p + geom_vline(xintercept=0.4095, colour="blue", linetype="dotted")
p <- p + geom_text(aes(x=0.19475, y=1.3, label="SheW = 0.19475"), 
            colour="red", angle=90, vjust = 1.2, text=element_text(size=7))
Warning: Ignoring unknown parameters: text
p <- p + geom_text(aes(x=0.06775, y=1.3, label="SheW = 0.06775"), 
            colour="black", angle=90, vjust = 1.2, text=element_text(size=7))
Warning: Ignoring unknown parameters: text
p <- p + geom_text(aes(x=0.4095, y=1.3, label="SheW = 0.4095"), 
            colour="blue", angle=90, vjust = 1.2, text=element_text(size=7))
Warning: Ignoring unknown parameters: text
p